
Conversation

@ashors1
Contributor

@ashors1 ashors1 commented Sep 29, 2025

What does this PR do ?

Previously, if save_period % val_period != 0, we would simply save the most recent k checkpoints. With this change, if fewer than k checkpoints have validation metrics (say, m of them), we fall back to saving the k-m most recent checkpoints in addition to the m checkpoints with validation metrics.
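The selection rule can be sketched as follows (a minimal illustration, not the actual NeMo-RL implementation; the checkpoint tuples and keep_top_k name are assumptions for this example):

```python
# Minimal sketch of the fallback described above (hypothetical data layout).
# Each checkpoint is (step, training_info); lower metric is better here.
checkpoints = [
    (1, {"loss": 0.5}),
    (2, {}),  # saved between validations: no val metric recorded
    (3, {"loss": 0.2}),
    (4, {}),
    (5, {}),
]
keep_top_k = 3
metric = "loss"

# Missing metrics default to +inf, so metric-bearing checkpoints rank first;
# ties among metric-less checkpoints are broken by recency (higher step).
ranked = sorted(checkpoints, key=lambda c: (c[1].get(metric, float("inf")), -c[0]))
kept = ranked[:keep_top_k]
print(sorted(step for step, _ in kept))  # → [1, 3, 5]
```

With k=3 and only m=2 checkpoints carrying metrics, the two metric-bearing checkpoints (steps 1 and 3) are kept along with the k-m=1 most recent metric-less checkpoint (step 5).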


Issues

closes #1214


Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.


Summary by CodeRabbit

  • New Features
    • Clearer warnings when checkpoint metrics are missing: “This checkpoint will not be saved as top-k.”
  • Improvements
    • Consistent handling of missing metrics across DPO, GRPO, and SFT training flows.
    • More robust top-k checkpoint selection that safely handles absent metrics without errors.
  • Bug Fixes
    • Prevents unintended clearing of the configured checkpoint metric, preserving user settings and expected behavior.

@ashors1 ashors1 requested a review from terrykong September 29, 2025 21:01
@ashors1 ashors1 requested review from a team as code owners September 29, 2025 21:01
@coderabbitai
Contributor

coderabbitai bot commented Sep 29, 2025

📝 Walkthrough

Walkthrough

Checkpointing behavior is updated across DPO, GRPO, and SFT training to stop mutating the configured checkpoint metric when it’s missing and to simplify warning messages. Checkpoint sorting now safely handles missing metrics via dictionary get with default sentinel values, removing KeyError handling and fallback paths.

Changes

Cohort / File(s) — Summary

  • Algorithms: training checkpoint metric handling (nemo_rl/algorithms/dpo.py, nemo_rl/algorithms/grpo.py, nemo_rl/algorithms/sft.py): When the configured checkpoint metric is absent from the save state, warnings are shortened and master_config["checkpointing"]["metric_name"] is no longer set to None. Control flow no longer mutates the metric setting; top-k applicability is only messaged.
  • Checkpoint sorting logic (nemo_rl/utils/checkpoint.py): Replaced try/except KeyError sorting with safe dict access, using x[2].get(metric_name, ±inf) for ordering. Removed fallback warnings and metric-reset paths. Behavior is unchanged when metric_name is None (sort by step).
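The switch from exception handling to safe access amounts to a one-line change in the sort key. A sketch (variable names here are assumptions, not the exact checkpoint.py code):

```python
import math

# Hypothetical checkpoint tuples: (step, path, training_info)
checkpoints = [(10, "step_10", {"loss": 0.4}), (20, "step_20", {})]
metric_name = "loss"
higher_is_better = False

# Old approach (simplified): direct indexing raised KeyError for
# metric-less checkpoints and triggered a fallback path:
#     key = lambda x: x[2][metric_name]   # KeyError on step_20
# New approach: .get() with a sentinel ranks metric-less checkpoints last.
default = -math.inf if higher_is_better else math.inf
ordered = sorted(checkpoints, key=lambda x: x[2].get(metric_name, default))
print([x[0] for x in ordered])  # → [10, 20]
```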

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Trainer
  participant Algo as DPO/GRPO/SFT
  participant Ckpt as CheckpointManager

  rect rgba(230,240,255,0.5)
  note over Trainer,Algo: New flow on missing metric
  Trainer->>Algo: train_step()
  Algo->>Ckpt: save_checkpoint(save_state, metric_name)
  alt metric_name set AND metric missing
    Ckpt-->Algo: compute rank using get(metric, ±inf)
    Algo-->>Trainer: Warn "This checkpoint will not be saved as top-k."
    note right of Algo: metric_name is NOT mutated
  else metric present or metric_name None
    Ckpt-->Algo: sort normally (by metric or step)
  end
  end
sequenceDiagram
  autonumber
  participant Trainer
  participant Algo as DPO/GRPO/SFT
  participant Ckpt as CheckpointManager

  rect rgba(255,240,230,0.5)
  note over Trainer,Algo: Previous flow on missing metric
  Trainer->>Algo: train_step()
  Algo->>Ckpt: save_checkpoint(save_state, metric_name)
  Ckpt--x Algo: KeyError during metric sort
  Algo-->>Trainer: Warn about fallback
  note right of Algo: metric_name set to None (mutated)
  Ckpt-->Algo: fallback to recent-k by step
  end

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Pre-merge checks and finishing touches

❌ Failed checks (3 warnings)

  • Title Check ⚠️ Warning — The title "fix: fix checkpointing when val_period does not divide save_period" describes only the val_period/save_period divisibility edge case. The actual changes update checkpoint warning messages, remove the metric_name nullification logic across dpo.py, grpo.py, and sft.py, and replace KeyError handling in checkpoint.py with safe get-based access using inf/-inf defaults. The title does not convey the main change: how checkpoints are sorted and selected when metrics are missing. Resolution: consider a title such as "fix: handle missing validation metrics in top-k checkpoint selection" or "fix: improve checkpoint sorting when validation metrics are unavailable".
  • Docstring Coverage ⚠️ Warning — Docstring coverage is 50.00%, below the required 80.00% threshold. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
  • Test Results For Major Changes ⚠️ Warning — This PR modifies core sorting logic in the CheckpointManager (sentinel values of -inf/inf for missing metrics) and adjusts behavior across DPO, GRPO, and SFT, but the description includes no test results validating the new behavior. Existing unit tests do not cover the mixed-metric scenario where k > m checkpoints should include both metric-based and recency-based selections. Resolution: include test results demonstrating that when fewer than k checkpoints have validation metrics (m < k), the system keeps those m checkpoints plus the k-m most recent checkpoints without metrics; given the open review comment on tie-breaking for higher_is_better=True, also confirm the sorting behavior is correct for both branches before merge.

✅ Passed checks (1 passed)

  • Description Check ✅ Passed — Check skipped; CodeRabbit's high-level summary is enabled.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
nemo_rl/algorithms/grpo.py (1)

1669-1673: Inconsistent warning message in async_grpo_train.

The warning message at line 1671 still states "Saving most recent k checkpoints instead," which differs from the updated message at line 877 in the synchronous grpo_train function ("This checkpoint will not be saved as top-k."). Additionally, line 1673 still mutates master_config["checkpointing"]["metric_name"] by setting it to None, which was removed in the synchronous version.

For consistency with the rest of this PR, update the async path to match the synchronous path:

  • Change the warning message to match line 877
  • Remove the metric_name mutation at line 1673

Apply this diff:

                             warnings.warn(
                                 f"You asked to save checkpoints based on {master_config['checkpointing']['metric_name']} but the metric is not found in the save state. "
-                                "Saving most recent k checkpoints instead."
+                                "This checkpoint will not be saved as top-k."
                             )
-                            master_config["checkpointing"]["metric_name"] = None
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 17ea9ab and 36a055d.

📒 Files selected for processing (4)
  • nemo_rl/algorithms/dpo.py (1 hunks)
  • nemo_rl/algorithms/grpo.py (1 hunks)
  • nemo_rl/algorithms/sft.py (1 hunks)
  • nemo_rl/utils/checkpoint.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Follow the Google Python Style Guide for all Python code
Target Python 3.12+ for all Python code in NeMo-RL
Indent Python code with 4 spaces; do not use tabs
Python filenames should be snake_case (e.g., some_file.py)
Class names should be PascalCase
Function and method names should be snake_case
Local variable names should be snake_case; if starting with a number, prefix with k (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE and prefixed with G_ (e.g., G_MY_GLOBAL)
Constants should be UPPER_SNAKE_CASE
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
For public interfaces used outside a file, prefer docstrings over comments
Use comments mainly for code within a function or interfaces local to a file
Commented-out code must include a nearby comment explaining usage and why it is commented out; otherwise remove before merging
Use Google-style docstrings for classes and functions (Sphinx-parseable)
Avoid using reflection when functionality can be easily achieved without it
Limit except clauses to the smallest specific set of exceptions possible
For duck-typing via try/except, keep the try body minimal and use else for main logic
Add the NVIDIA copyright header (with current year) at the top of all Python files, excluding tests/ and test-only scripts

Files:

  • nemo_rl/algorithms/grpo.py
  • nemo_rl/algorithms/sft.py
  • nemo_rl/algorithms/dpo.py
  • nemo_rl/utils/checkpoint.py
nemo_rl/**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

nemo_rl/**/*.py: Do not set non-None configuration defaults in code; YAML is the single source of truth for defaults
Access required config attributes directly (e.g., policy_cfg["precision"]) and assume presence; do not introduce hidden defaults
Express configuration optionality via TypedDict using typing.NotRequired
When adding a new config key to a TypedDict subclass, document the key’s purpose, valid values/types, and recommended default in code
For any class or function decorated with @ray.remote, add '# pragma: no cover' on the class/def line (and on remote functions)

Files:

  • nemo_rl/algorithms/grpo.py
  • nemo_rl/algorithms/sft.py
  • nemo_rl/algorithms/dpo.py
  • nemo_rl/utils/checkpoint.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Lint check
  • GitHub Check: Post automodel integration comment / Comment on PR
  • GitHub Check: Post submodule check comment / Comment on PR
🔇 Additional comments (3)
nemo_rl/algorithms/grpo.py (1)

875-878: LGTM: Warning message updated and metric_name mutation removed.

The change correctly updates the warning message and removes the mutation of metric_name, which aligns with the updated checkpoint sorting logic in checkpoint.py that now handles missing metrics using safe .get() calls with sentinel defaults.

nemo_rl/algorithms/dpo.py (1)

629-637: LGTM: Consistent warning and no metric_name mutation.

The change correctly updates the warning message to match the other algorithm files and removes the mutation of metric_name. This is consistent with the updated checkpoint sorting logic that handles missing metrics safely.

nemo_rl/algorithms/sft.py (1)

502-510: LGTM: Consistent warning and no metric_name mutation.

The change correctly updates the warning message to match the other algorithm files and removes the mutation of metric_name. This is consistent with the updated checkpoint sorting logic that handles missing metrics safely.

Signed-off-by: ashors1 <[email protected]>
Contributor

@terrykong terrykong left a comment


@ashors1 can you add a test that tests the possibility that some steps don't have val metrics?

import json

def test_remove_old_checkpoints(checkpoint_manager, checkpoint_dir):
    # Create multiple checkpoints with different loss values
    steps = [1, 2, 3, 4, 5, 6]
    losses = [0.5, 0.3, 0.7, 0.2, 0.4, 0.8]
    for step, loss in zip(steps, losses):
        training_info = {"loss": loss}
        tmp_dir = checkpoint_manager.init_tmp_checkpoint(step, training_info)
        checkpoint_manager.finalize_checkpoint(tmp_dir)
    # Check that only top-k checkpoints are kept
    remaining_dirs = list(checkpoint_dir.glob("step_*"))
    assert (
        len(remaining_dirs) == checkpoint_manager.keep_top_k + 1
    )  # +1 because we exclude the latest
    # Verify the remaining checkpoints are the ones with the lowest loss
    remaining_losses = []
    for dir_path in remaining_dirs:
        with open(dir_path / "training_info.json", "r") as f:
            metadata = json.load(f)
        remaining_losses.append(metadata["loss"])
    assert sorted(remaining_losses) == sorted(losses)[
        : checkpoint_manager.keep_top_k
    ] + [0.8]  # exclude latest

def test_remove_old_checkpoints_topk_bias_recent_if_equal(
    checkpoint_manager, checkpoint_dir
):
    # Create multiple checkpoints with the same loss value
    steps = [1, 2, 3, 4, 10, 12]
    losses = [0.5, 0.5, 0.5, 0.5, 0.5, 0.5]  # All checkpoints have the same loss
    for step, loss in zip(steps, losses):
        training_info = {"loss": loss}
        tmp_dir = checkpoint_manager.init_tmp_checkpoint(step, training_info)
        checkpoint_manager.finalize_checkpoint(tmp_dir)
    # Check that only top-k checkpoints are kept
    remaining_dirs = list(checkpoint_dir.glob("step_*"))
    assert len(remaining_dirs) == checkpoint_manager.keep_top_k
    # When all losses are equal, the most recent checkpoints should be kept
    remaining_steps = []
    for dir_path in remaining_dirs:
        step_num = int(dir_path.name.split("_")[1])
        remaining_steps.append(step_num)
    # Should keep the most recent checkpoints (highest step numbers)
    expected_steps = sorted(steps)[-checkpoint_manager.keep_top_k:]
    assert sorted(remaining_steps) == sorted(expected_steps)

What happens if save_period and val_period are not divisible? Does that mean some metrics will have inf and -inf and get pruned out? Correct me if I've read that wrong.

Regarding the top_k to save, I think we should have a test that guards the latest ckpt even if it doesn't have a val metric.

@ashors1
Contributor Author

ashors1 commented Sep 30, 2025

@ashors1 can you add a test that tests the possibility that some steps don't have val metrics?

Will do

What happens if save_period and val_period are not divisible

If some checkpoints don't have an associated val metric, all of those checkpoints get the same default metric value (inf or -inf), in which case we sort them by step number. So we first select the top-k checkpoints using val metrics, and fall back to step number when we don't have enough checkpoints with val metrics to populate the top-k.
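Concretely, for the higher_is_better case the fallback behaves like this sketch (hypothetical data; the real CheckpointManager stores more fields per checkpoint):

```python
from math import inf

# Steps 2, 4, and 5 were saved between validations, so they carry no
# "reward" metric; they all default to -inf and rank after metric-bearing
# checkpoints. Among them, higher step numbers (more recent) win ties.
ckpts = {1: {"reward": 0.8}, 2: {}, 3: {"reward": 0.9}, 4: {}, 5: {}}
k = 3

# Primary key: metric (descending); secondary key: step (descending),
# so among metric-less checkpoints the most recent survive.
ranked = sorted(ckpts, key=lambda s: (ckpts[s].get("reward", -inf), s), reverse=True)
print(sorted(ranked[:k]))  # → [1, 3, 5]
```

Here the two metric-bearing checkpoints (steps 1 and 3) fill two top-k slots, and the remaining slot goes to step 5, the most recent metric-less checkpoint.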

Signed-off-by: ashors1 <[email protected]>
@ashors1 ashors1 requested a review from a team as a code owner September 30, 2025 21:20
Signed-off-by: ashors1 <[email protected]>
@terrykong terrykong added the CI:L1 Run doctests, unit tests, and functional tests label Oct 1, 2025
@terrykong terrykong enabled auto-merge (squash) October 1, 2025 04:49
@terrykong terrykong merged commit d82ca75 into main Oct 1, 2025
54 of 58 checks passed
@terrykong terrykong deleted the ashors/topk-ckpt branch October 1, 2025 18:21

Labels

CI:L1 Run doctests, unit tests, and functional tests

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Unexpected behavior with save_period=1 and val_period=10

3 participants